import os
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST, CIFAR10
from torchvision import transforms


def get_data(batch_size: int, data_name: str = 'MNIST', *,
             root: str = './data/', num_workers: int = 2):
    """
    Build the training and test DataLoaders for the requested dataset.

    The dataset files are downloaded into ``root`` if not already present.

    :param batch_size: number of samples per batch, used for both loaders
    :param data_name: dataset name, either 'MNIST' or 'CIFAR10'
    :param root: directory the dataset is stored in / downloaded to (default './data/')
    :param num_workers: worker subprocesses used by each DataLoader (default 2)
    :return: tuple ``(train_loader, test_loader)``; the training loader shuffles,
             the test loader does not
    :raises ValueError: if ``data_name`` is not one of the supported datasets
    """
    supported = ('MNIST', 'CIFAR10')
    # Validate before touching the filesystem: an unknown name used to fall
    # through both branches and crash later with UnboundLocalError.
    if data_name not in supported:
        raise ValueError(
            f"Unsupported data_name {data_name!r}; expected one of {supported}"
        )

    os.makedirs(root, exist_ok=True)

    if data_name == 'MNIST':
        transform = transforms.Compose([
            transforms.ToTensor(),
            # Standardize with the commonly used MNIST mean / std values.
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        dataset_cls = MNIST
    else:  # 'CIFAR10'
        transform = transforms.Compose([
            transforms.Resize(224),  # resize to the ResNet input size
            transforms.ToTensor(),
            # (x - 0.5) / 0.5 per channel: maps ToTensor's [0, 1] range to [-1, 1].
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        dataset_cls = CIFAR10

    train_dataset = dataset_cls(root=root, train=True, download=True, transform=transform)
    test_dataset = dataset_cls(root=root, train=False, download=True, transform=transform)

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    return train_loader, test_loader


if __name__ == '__main__':
    # Smoke test: build the loaders for each supported dataset
    # (the first run downloads the data).
    BATCH_SIZE = 32
    mnist_train, mnist_test = get_data(BATCH_SIZE, data_name='MNIST')
    cifar10_train, cifar10_test = get_data(BATCH_SIZE, data_name='CIFAR10')
